#include "RayTransform.hpp"
#include <Math/Math.hpp>

namespace zzz{
// Unpolarized Fresnel transmittance for a ray crossing an interface.
//   costheta1: cosine of the angle on the incident side; values above 1 are
//              clamped to 1.
//   eta:       factor relating the two sines, sin(theta2) = eta * sin(theta1).
// Returns the average of the s- and p-polarized transmittances, or 0 when the
// refracted sine would exceed 1 (total internal reflection).
float Ft_in(float costheta1,float eta)
{
  // Rounding can push a cosine marginally past 1; clamp so the sqrt stays real.
  const float cosIn=(costheta1>1.0)?1.0f:costheta1;
  const float sinIn=sqrt(1.0-cosIn*cosIn);

  // Snell's law gives the sine on the far side of the interface.
  const float sinOut=sinIn*eta;
  if (sinOut>1.0) return 0;  // total internal reflection: nothing is transmitted
  const float cosOut=sqrt(1.0-sinOut*sinOut);

  // Fresnel amplitude reflection coefficients for the two polarizations.
  const float ampS=(eta*cosIn-cosOut)/(eta*cosIn+cosOut);
  const float ampP=(eta*cosOut-cosIn)/(eta*cosOut+cosIn);

  // Reflectance is the squared amplitude; transmittance is its complement.
  const float reflS=ampS*ampS;
  const float reflP=ampP*ampP;
  const float transS=1.0-reflS;
  const float transP=1.0-reflP;
  return (transS+transP)/2.0;
}

// Unpolarized Fresnel transmittance evaluated with the reciprocal index ratio.
//   costheta2: cosine of the angle on the side the ray starts from; values
//              above 1 are clamped to 1.
//   eta:       index ratio; its reciprocal relates the sines,
//              sin(theta1) = sin(theta2) / eta.
// Returns the average of the s- and p-polarized transmittances, or 0 when the
// sine on the far side would exceed 1 (total internal reflection).
float Ft_out(float costheta2,float eta)
{
  const float invEta=1.0 / eta;

  // Keep the cosine in range so the square root below stays real.
  const float cosFrom=(costheta2>1.0)?1.0f:costheta2;
  const float sinFrom=sqrt(1.0-cosFrom*cosFrom);

  // Snell's law with the reciprocal ratio.
  const float sinTo=sinFrom*invEta;
  if (sinTo>1.0) return 0;  // total internal reflection
  const float cosTo=sqrt(1.0-sinTo*sinTo);

  // s-polarized reflectance, then its transmittance.
  const float ampS=(invEta*cosFrom-cosTo)/(invEta*cosFrom+cosTo);
  const float reflS=ampS*ampS;
  const float transS=1.0-reflS;

  // p-polarized reflectance, then its transmittance.
  const float ampP=(invEta*cosTo-cosFrom)/(invEta*cosTo+cosFrom);
  const float reflP=ampP*ampP;
  const float transP=1.0-reflP;

  return (transS+transP)/2.0;
}

// Unpolarized Fresnel transmittance for the reverse crossing, parameterized
// by the angle on side 1.
//   costheta1: cosine of the angle on side 1; values above 1 are clamped to 1
//              (matching Ft_in and Ft_out) so rounding error in the caller's
//              dot product cannot turn the result into NaN.
//   eta:       index ratio with sin(theta2) = eta * sin(theta1); the
//              reflection coefficients themselves use 1/eta.
// Returns the average of the s- and p-polarized transmittances, or 0 when
// sin(theta2) would exceed 1 (total internal reflection).
float Ft_out2(float costheta1,float eta)
{
  const float eta2=1.0 / eta;
  // Without this clamp a cosine marginally above 1 makes the sqrt argument
  // negative and every later value NaN.
  if (costheta1>1.0) costheta1=1.0;
  float sintheta1=sqrt(1.0-costheta1*costheta1);
  float sintheta2=(sintheta1 * eta);
  if (sintheta2>1.0) return 0;  // total internal reflection
  float costheta2=sqrt(1.0-sintheta2*sintheta2);
  // s-polarized reflectance (squared amplitude coefficient).
  float Rs=(eta2*costheta2-costheta1)/(eta2*costheta2+costheta1);
  Rs=Rs*Rs;
  // p-polarized reflectance (squared amplitude coefficient).
  float Rp=(eta2*costheta1-costheta2)/(eta2*costheta1+costheta2);
  Rp=Rp*Rp;
  // Transmittance is the complement of reflectance; average both polarizations.
  float Ts=1.0-Rs,Tp=1.0-Rp;
  return (Ts+Tp)/2.0;
}

// Refracts a ray through a surface using the vector form of Snell's law.
//   inray:  direction pointing from the surface back along the incoming ray;
//           normal.Dot(inray) is used as the incidence cosine.
//           Assumed to be unit length, as is normal — TODO confirm with callers.
//   outray: receives the normalized refracted direction; left untouched when
//           the function returns false.
//   eta:    index ratio with sin(theta2) = eta * sin(theta1).
//   normal: surface normal.
// Returns false on total internal reflection (eta * sin(theta1) > 1),
// true otherwise.
bool RefractTo(Vector3f inray,Vector3f &outray,float eta,const Vector3f &normal/*=Vector3f(0,0,1)*/)
{//default normal is (0,0,1)
  float cos1=normal.Dot(inray);
  if (cos1>1-EPS)
  {
    // (Near-)normal incidence: the ray continues straight through.
    outray=-inray;
    return true;
  }
  // Clamp so a cosine marginally below -1 cannot hand sqrt a negative value.
  float sin1sq=1-cos1*cos1;
  if (sin1sq<0) sin1sq=0;
  float sin1=sqrt(sin1sq);
  float sin2=sin1*eta;
  if (sin2>1) return false;  // total internal reflection
  float cos2=sqrt(1.0-sin2*sin2);
  // outray = -eta*inray + (eta*cos1 - cos2)*normal.
  // This is algebraically the same direction as scaling the tangential offset
  // by tan(theta2)/tan(theta1), but it needs no division by cos1 or cos2, so
  // grazing incidence (cos1 == 0) and the critical angle (cos2 == 0) no
  // longer produce NaN/inf components.
  inray*=eta;
  outray=(-normal)*(cos2-eta*cos1)-inray;
  outray.Normalize();
  return true;
}

// Derives an incoming direction from an outgoing ray and stores it in inray.
//   outray: ray on the far side of the surface.
//   inray:  receives the normalized vector -normal + eta*(normal - outray).
//   eta:    scale applied to the (normal - outray) offset.
//   normal: surface normal.
// Always returns true.
// NOTE(review): unlike RefractTo, outray is not rescaled to unit height along
// the normal before the offset is taken, so the offset keeps a component
// along normal and the result does not appear to satisfy Snell's law exactly;
// total internal reflection is never reported either. Confirm whether this
// approximation is intended before relying on it.
bool RefractFrom(Vector3f outray,Vector3f &inray,float eta,const Vector3f &normal/*=Vector3f(0,0,1)*/)
{//default normal is (0,0,1)
  Vector3f offset=normal-outray;
  offset*=eta;
  inray=-normal+offset;
  inray.Normalize();
  return true;
}
}
